Shape
形状推断函数(Shape Inference Function)。该函数根据输入张量的形状和算子参数,推断输出张量的形状。该函数不区分数据类型,只处理张量的形状信息。
如果所有输出张量都是常量(ConstTensor 或 ConstScalar),则直接返回,不进行形状推断。否则,根据算子类型调用相应的形状推断函数。
- 支持的算子类型:
Arithmetic_InferShape - 算术运算的形状推断
Common_InferShape - 通用算子的形状推断
Softmax_InferShape - Softmax 算子的形状推断
MaxMinGrad_InferShape - MaxMin 梯度算子的形状推断
Dropout_InferShape - Dropout 算子的形状推断
DynamicQuant_InferShape - 动态量化算子的形状推断
Fft_InferShape - FFT 算子的形状推断
Flatten_InferShape - Flatten 算子的形状推断
LayerNorm_InferShape - LayerNorm 算子的形状推断
LogSoftmax_InferShape - LogSoftmax 算子的形状推断
- 输入:
inputs - 输入张量数组(TensorC** 类型)。
inputs_size - 输入张量的数量。
outputs - 输出张量数组(TensorC** 类型)。
outputs_size - 输出张量的数量。
param - 算子参数(OpParameter* 类型),包含算子类型和其他参数信息。
- 输出:
outputs - 输出张量数组,其中的形状信息会被更新。
- 支持平台:
FT78NEMT7004
备注
该函数不区分数据类型,适用于所有数据类型
函数会自动检查输出是否为常量,如果是常量则跳过形状推断
共享存储/私有存储版本:
-
void shape(TensorC **inputs, int inputs_size, TensorC **outputs, int outputs_size, OpParameter *param)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <shape.h>
4
5int main(int argc, char* argv[]) {
6 TensorC** input_tensors_ptrs = (TensorC**)0x10010000;
7 TensorC** output_tensors_ptrs = (TensorC**)0x10011000;
8
9 TensorC input0;
10 TensorC input1;
11 TensorC output;
12
13 int input0_shape[4] = {1,2,3,4};
14 int input1_shape[4] = {1,3,4};
15 int output_shape[4]; //不用初始化
16 memcpy(input0.shape_, input0_shape, 4 * sizeof(int));
17 input0.shape_size_ = 4;
18 memcpy(input1.shape_, input1_shape, 4 * sizeof(int));
19 input1.shape_size_ = 3;
20 input0.data_type_ = kNumberTypeFloat32;
21 input1.data_type_ = kNumberTypeFloat32;
22 input0.format_ = Format_NCHW;
23 input1.format_ = Format_NCHW;
24
25 input_tensors_ptrs[0] = &input0;
26 input_tensors_ptrs[1] = &input1;
27 output_tensors_ptrs[0] = &output;
28
29 ArithmeticParameter param;
30 param.op_parameter_.type_ = Arithmetic_InferShape;
31
32 shape(input_tensors_ptrs, 2, output_tensors_ptrs, 1, (OpParameter*)¶m);
33 return 0;
34}